# Copyright 2019-2021 Canaan Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""System test: tf.pad followed by 2-D reduce-window ops (max_pool2d / avg_pool2d)."""
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel

import pytest
import tensorflow as tf
import numpy as np
from tflite_test_runner import TfliteTestRunner


def _make_module(n, i_channels, i_size, k_size, strides, padding, pad):
    """Build a tf.Module that pads its input, then max- and avg-pools it.

    Args:
        n: Batch size of the input signature.
        i_channels: Number of input channels.
        i_size: Spatial input size as ``[height, width]``.
        k_size: Pooling window size, forwarded to ``tf.nn.max_pool2d`` and
            ``tf.nn.avg_pool2d``.
        strides: Pooling strides, forwarded to both pooling ops.
        padding: Pooling padding mode (``'SAME'`` or ``'VALID'`` in this file).
        pad: Explicit paddings passed to ``tf.pad`` before pooling; one
            ``[before, after]`` pair per input dimension, in the
            ``[n, height, width, channels]`` order of the input signature.

    Returns:
        A ``ReduceWindow2DModule`` instance whose ``__call__`` returns
        ``[max_pool_output, avg_pool_output]``.
    """
    class ReduceWindow2DModule(tf.Module):
        def __init__(self):
            # Zero-argument super() binds to this instance, so tf.Module's
            # initializer really runs. The single-argument form
            # ``super(ReduceWindow2DModule)`` is an unbound super object and
            # silently skips tf.Module.__init__.
            super().__init__()

        @tf.function(input_signature=[tf.TensorSpec([n, *i_size, i_channels], tf.float32)])
        def __call__(self, x):
            # Each pooling op gets its own tf.pad; kept separate so the
            # exported graph has the same pad -> pool structure as before.
            outs = []
            outs.append(tf.nn.max_pool2d(tf.pad(x, tf.constant(pad)), k_size, strides, padding))
            outs.append(tf.nn.avg_pool2d(tf.pad(x, tf.constant(pad)), k_size, strides, padding))
            return outs
    return ReduceWindow2DModule()


# Batch sizes for the input signature.
n = [1, 3]

# Input channel counts.
i_channels = [1, 16]

# Spatial input sizes, each given as [height, width].
i_sizes = [[112, 112]]

# Square pooling windows: 3x3 and 5x5.
k_sizes = [[k, k] for k in (3, 5)]

# Square pooling strides: 1x1 and 2x2.
strides = [[s, s] for s in (1, 2)]

# Padding modes handed to the pooling ops.
paddings = ['SAME', 'VALID']

# Explicit tf.pad paddings: one [before, after] pair per input dimension,
# in [n, height, width, channels] order. The first entry pads height
# asymmetrically; the second pads both spatial dims symmetrically.
pads = [
    [[0, 0], [1, 0], [1, 1], [0, 0]],
    [[0, 0], [1, 1], [1, 1], [0, 0]],
]


@pytest.mark.parametrize('n', n)
@pytest.mark.parametrize('i_channels', i_channels)
@pytest.mark.parametrize('i_size', i_sizes)
@pytest.mark.parametrize('k_size', k_sizes)
@pytest.mark.parametrize('strides', strides)
@pytest.mark.parametrize('padding', paddings)
@pytest.mark.parametrize('pad', pads)
def test_pad_reduce_window2d(n, i_channels, i_size, k_size, strides, padding, pad, request):
    """Convert the pad + max/avg-pool module to TFLite and run it.

    With ``'VALID'`` padding the pooling window must fit inside the input
    seen by the pool, which is the tensor *after* ``tf.pad``. Combinations
    where it does not fit are reported as skipped instead of passing
    silently without running anything.
    """
    if padding == 'VALID':
        # pad follows the [n, height, width, channels] input order, so
        # pad[1] widens the height and pad[2] widens the width.
        padded_h = i_size[0] + pad[1][0] + pad[1][1]
        padded_w = i_size[1] + pad[2][0] + pad[2][1]
        if k_size[0] > padded_h or k_size[1] > padded_w:
            pytest.skip('kernel larger than padded input for VALID padding')

    module = _make_module(n, i_channels, i_size, k_size,
                          strides, padding, pad)

    runner = TfliteTestRunner(request.node.name)
    model_file = runner.from_tensorflow(module)
    runner.run(model_file)


if __name__ == "__main__":
    # Pass this file's own path so running it directly works from any
    # working directory, whatever the file is named.
    pytest.main(['-vv', __file__])
